/**
 *  LLMEndpoint
 *  Copyright 2024 by Michael Peter Christen
 *  First released 17.05.2024 at https://yacy.net
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Lesser General Public
 *  License as published by the Free Software Foundation; either
 *  version 2.1 of the License, or (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this program in the file lgpl21.txt
 *  If not, see <http://www.gnu.org/licenses/>.
 */

package net.yacy.ai;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.json.JSONTokener;

import net.yacy.search.Switchboard;

public class LLM {

    private static String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "<EOS_TOKEN>", "</s>", "<|end|>"};

    public static enum LLMType {
        OPENAI("https://api.openai.com"),
        OLLAMA("http://localhost:11434"),
        LMSTUDIO("http://localhost:1234"),
        OPENROUTER("https://openrouter.ai/api"),
        OTHER(null);
        public String hoststub;
        private LLMType(String hoststub) {
            this.hoststub = hoststub;
        }
    }
    
    public static enum LLMUsage {
        search,
        chat,
        translation,
        classification,
        query,
        qapairs,
        tldr
    }
    
    public static class LLMModel {
        public LLM llm;
        public String model;
        public LLMModel(LLM llm, String model) {
            this.llm = llm;
            this.model = model;
        }
    }
    
    public final String hoststub;
    public final String api_key;
    public final int max_tokens; // the max_tokens as configured by the endpoint for all models
    public final LLMType type;
    
    public LLM(final String hoststub, final String api_key, final int max_tokens, final LLMType type) {
        this.hoststub = hoststub.endsWith("/") ? hoststub.substring(0, hoststub.length() - 1) : hoststub;
        this.api_key = api_key == null ? "" : api_key;
        this.max_tokens = max_tokens <= 0 ? 4096 : max_tokens;
        this.type = type;
    }
    
    /**
     * The following function picks up the right model that was configured in the LLMSelection.
     * @param llmUsage
     * @return
     */
    public static LLMModel llmFromUsage(LLMUsage llmUsage) {
        Switchboard sb = Switchboard.getSwitchboard();
        String pms = sb.getConfig("ai.production_models", "[]");
        try {
            JSONArray production_models = new JSONArray(new JSONTokener(pms));
            // got through all the selected models to find which one has the wanted usage flag switched on
            for (int i = 0; i < production_models.length(); i++) {
                JSONObject row = production_models.getJSONObject(i);
                boolean switched_on = row.optBoolean(llmUsage.name(), false);
                if (switched_on) {
                    // found one that shall be used for this use case
                    final String hoststub = row.optString("hoststub", "");
                    final String api_key = row.optString("api_key", "");
                    final int max_tokens = Integer.parseInt(row.optString("max_tokens", "4096"));
                    final String model = row.optString("model", "");
                    final LLMType type = LLMType.valueOf(row.optString("service", "OLLAMA"));
                    LLM llm = new LLM(hoststub, api_key, max_tokens, type);
                    LLMModel llmmodel = new LLMModel(llm, model);
                    return llmmodel;
                }
            }
        } catch (JSONException | NumberFormatException e) {
            e.printStackTrace();
        }
        // so if we don't find a model for that specific usage, we purposely return null to show that there is a missing configuration
        return null;
    }    
    
    public String getHoststub() {
		return this.hoststub;
	}


    // API Helper Methods

    private static String sendPostRequest(final String urls, final JSONObject data) throws IOException, URISyntaxException {
        final URL url = new URI(urls).toURL();
        final HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        conn.setRequestMethod("POST");
        conn.setRequestProperty("Content-Type", "application/json");
        conn.setDoOutput(true);

        try (OutputStream os = conn.getOutputStream()) {
            final byte[] input = data.toString().getBytes("utf-8");
            os.write(input, 0, input.length);
        }

        final int responseCode = conn.getResponseCode();
        if (responseCode == HttpURLConnection.HTTP_OK) {
            try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) {
                final StringBuilder response = new StringBuilder();
                String responseLine;
                while ((responseLine = br.readLine()) != null) {
                    response.append(responseLine.trim());
                }
                return response.toString();
            }
        } else {
            throw new IOException("Request failed with response code " + responseCode);
        }
    }

    private static String sendGetRequest(final String urls) throws IOException, URISyntaxException {
        final URL url = new URI(urls).toURL();
        final HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        conn.setRequestMethod("GET");

        final int responseCode = conn.getResponseCode();
        if (responseCode == HttpURLConnection.HTTP_OK) {
            try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) {
                final StringBuilder response = new StringBuilder();
                String responseLine;
                while ((responseLine = br.readLine()) != null) {
                    response.append(responseLine.trim());
                }
                return response.toString();
            }
        } else {
            throw new IOException("Request failed with response code " + responseCode);
        }
    }
    
    
    public LinkedHashMap<String, Long> listOllamaModels() {
        final LinkedHashMap<String, Long> sortedMap = new LinkedHashMap<>();
        try {
            final String response = sendGetRequest(this.hoststub + "/api/tags");
            final JSONObject responseObject = new JSONObject(response);
            final JSONArray models = responseObject.getJSONArray("models");

            final List<Map.Entry<String, Long>> list = new ArrayList<>();
            for (int i = 0; i < models.length(); i++) {
                final JSONObject model = models.getJSONObject(i);
                final String name = model.optString("name", "");
                final long size = model.optLong("size", 0);
                list.add(new AbstractMap.SimpleEntry<>(name, size));
            }

            // Sort the list in descending order based on the values
            list.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue()));

            // Create a new LinkedHashMap and add the sorted entries
            for (final Map.Entry<String, Long> entry : list) {
                sortedMap.put(entry.getKey(), entry.getValue());
            }
        } catch (JSONException | URISyntaxException | IOException e) {
            e.printStackTrace();
        }
        return sortedMap;
    }

    public boolean ollamaModelExists(final String name) {
        final JSONObject data = new JSONObject();
        try {
            data.put("name", name);
            sendPostRequest(this.hoststub + "/api/show", data);
            return true;
        } catch (JSONException | URISyntaxException | IOException e) {
            return false;
        }
    }

    public boolean pullOllamaModel(final String name) {
        final JSONObject data = new JSONObject();
        try {
            data.put("name", name);
            data.put("stream", false);
            final String response = sendPostRequest(this.hoststub + "/api/pull", data);
            // this sends {"status": "success"} in case of success
            final JSONObject responseObject = new JSONObject(response);
            final String status = responseObject.optString("status", "");
            return status.equals("success");
        } catch (JSONException | URISyntaxException | IOException e) {
            return false;
        }
    }
    
    // chat endpoints
    
    public static class Context extends JSONArray {
        public Context(String systemPrompt) throws JSONException {
            super();
            final JSONObject systemPromptObject = new JSONObject(true);
            systemPromptObject.put("role", "system");
            systemPromptObject.put("content", systemPrompt);
            this.put(systemPromptObject);
        }
        public void addDialog(String user, String assistant) throws JSONException {
            final JSONObject userPromptObject = new JSONObject(true);
            userPromptObject.put("role", "user");
            userPromptObject.put("content", user);
            this.put(userPromptObject);
            final JSONObject assistantPromptObject = new JSONObject(true);
            assistantPromptObject.put("role", "assistant");
            assistantPromptObject.put("content", assistant);
            this.put(assistantPromptObject);
        }
        public void addPrompt(String userPrompt) throws JSONException {
            final JSONObject userPromptObject = new JSONObject(true);
            userPromptObject.put("role", "user");
            userPromptObject.put("content", userPrompt);
            this.put(userPromptObject);
        }
    }

    // OpenAI chat client, works with llama.cpp and Ollama
    public String chat(final String model, final Context context, JSONObject schema, final int max_tokens) throws IOException {
        final JSONObject data = new JSONObject();
        
        try {
            data.put("model", model);
            data.put("temperature", 0.1);
            data.put("max_tokens", max_tokens);
            data.put("messages", context);
            data.put("stop", new JSONArray(STOPTOKENS));
            data.put("stream", false);

            if (schema != null) {
                System.out.println(schema.toString());
                JSONObject json_schema = new JSONObject(true);
                json_schema.put("strict", true);
                json_schema.put("schema", schema);
                JSONObject response_format = new JSONObject();
                response_format.put("type", "json_schema");
                response_format.put("json_schema", json_schema);            
                data.put("response_format", response_format);
            }
            
            final String response = sendPostRequest(this.hoststub + "/v1/chat/completions", data);
            final JSONObject responseObject = new JSONObject(response);
            final JSONArray choices = responseObject.getJSONArray("choices");
            final JSONObject choice = choices.getJSONObject(0);
            final JSONObject message = choice.getJSONObject("message");
            final String content = message.optString("content", "");
            return content;
        } catch (JSONException | URISyntaxException e) {
            throw new IOException(e.getMessage());
        }
    }
    
    public String chat(final String model, final String systemPrompt, final String userPrompt, final int max_tokens) throws IOException {
        try {
            Context context = new Context(systemPrompt);
            context.addPrompt(userPrompt);
            return chat(model, context, null, max_tokens);
        } catch (JSONException e) {
            throw new IOException(e.getMessage());
        }
    }
    
    public static String[] stringsFromChat(String chatanswer) throws JSONException {
        JSONArray ja = new JSONArray(chatanswer);
        List<String> list = new ArrayList<>();
        // parse the JSON array and extract strings
        for (int i = 0; i < ja.length(); i++) {
            Object item = ja.get(i);
            if (item instanceof String) {
                list.add((String) item);
            } else if (item instanceof JSONObject) {
                JSONObject jo = (JSONObject) item;
                String answer = jo.optString("answer", null);
                if (answer != null) {
                    list.add(answer);
                } else {
                    // take any string value from the object
                    for (String key : jo.keySet()) {
                        Object value = jo.optString(key, null);
                        if (value != null && value instanceof String) {
                            list.add((String) value);
                            break; // take the first string found
                        }
                    }
                }
            }
        }
        // convert the list to an array
        String[] result = new String[list.size()];
        return list.toArray(result);        
    }
    
    public final static JSONObject listSchema = new JSONObject(Map.of(
        "title", "Answer List",
        "type", "array",
        "properties", Map.of(
            "answer", Map.of("type", "string")
        ),
        "required", List.of("answer")
    ));
    
    public static void main(final String[] args) {
        final LLM llm = new LLM(LLMType.OLLAMA.hoststub, null, 4069, LLMType.OLLAMA);

        final LinkedHashMap<String, Long> models = llm.listOllamaModels();
        System.out.println(models.toString());

        // check if model exists
        final String model = "qwen2.5:0.5b";
        if (llm.ollamaModelExists(model))
            System.out.println("model " + model + " exists");
        else
            System.out.println("model " + model + " does not exist");

        // pull a model
        final boolean success = llm.pullOllamaModel(model);
        System.out.println("pulled model: " + model + ": " + success);
        
        String response;
		try {
			response = llm.chat(model, "You are a helpful assistant.", "What is the capital of France?", 1000);
	        System.out.println("Chat response: " + response);
		} catch (IOException e) {
	
			e.printStackTrace();
		}

        // make chat completion with model
        String question = "Who invented the wheel?";
        try {
            final String answer = llm.chat(model, "Make short answers.", question, 200);
            System.out.println(answer);
        } catch (final IOException e) {
            e.printStackTrace();
        }

        // try the json parser from chat results
        question = "Make a list of four names from Star Wars movies. Use a JSON Array.";
        try {
            Context context = new Context("Make short answers");
            context.addPrompt(question);
            final String[] a = stringsFromChat(llm.chat(model, context, listSchema, 1000));
            for (String s : a) {
                System.out.println(s);
            }
        } catch (final IOException | JSONException e) {
            e.printStackTrace();
        }
    }
    
}
